from keras.datasets import mnist
import matplotlib.pyplot as plt

# Load MNIST as ((train_images, train_labels), (test_images, test_labels)).
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Summarize the dataset. The code treats axis 0 of the image arrays as the
# image count and axes 1/2 as height/width.
num_train_images = X_train.shape[0]
num_test_images = X_test.shape[0]
image_height, image_width = X_train.shape[1], X_train.shape[2]
print(f"Shape: {X_train.shape}")
print(f"Training images: {num_train_images}")
# Previously computed but never reported; include it so the summary covers
# both splits.
print(f"Test images: {num_test_images}")
print(f"Image height: {image_height}")
print(f"Image width: {image_width}")

# Preview a single training image (index 2). Pass cmap="gray" explicitly so it
# renders in grayscale like the sample grid; without it matplotlib maps 2-D
# arrays through its default (non-gray) colormap.
plt.imshow(X_train[2], cmap="gray")

# Show the first `num_samples` training images side by side in one row.
# Keep num_samples >= 2: with a single column plt.subplots returns one Axes
# object instead of an array, and the zip below would fail.
num_samples = 12
fig, axs = plt.subplots(1, num_samples, figsize=(17, 6))
for ax, image in zip(axs, X_train[:num_samples]):
    ax.imshow(image, cmap="gray")
    ax.axis("off")  # hide ticks and frame; only the digit matters here
# Without this call nothing is displayed when the file runs as a plain script.
plt.show()